"""ML-Ensemble

:author: Sebastian Flennerhag
:copyright: 2017-2018
:license: MIT

Estimator wrappers around base classes.
"""
from .. import config
from .base import BaseParallel, OutputMixin
from .backend import ParallelProcessing
from ..utils.exceptions import ParallelProcessingError, NotFittedError


class EstimatorMixin(object):

    """Estimator mixin

    Mixin class to build an estimator from a :mod:`mlens.parallel` backend
    class. The backend class should be set as the ``_backend`` attribute
    of the estimator during a ``fit`` call via a ``_build`` method. E.g::

        class Foo(EstimatorMixin, Learner):

            def __init__(self, ...):

                self._backend = None

            def _build(self):
                self._backend = Learner(...)

    It is recommended to combine :class:`EstimatorMixin` with
    `parallel.base.ParamMixin`.

    Instances exposing a ``__fitted__`` attribute are checked before
    ``predict`` and ``transform``: a falsy value raises
    :class:`NotFittedError`. Instances without that attribute are not
    checked.
    """

    def fit(self, X, y, proba=False, refit=True):
        """Fit

        Fit estimator.

        Parameters
        ----------
        X: array of size [n_samples, n_features]
            input data

        y: array of size [n_samples,]
            targets

        proba: bool, optional
            whether to fit for later predict_proba calls. Will register number
            of classes to expect in later predict and transform calls.

        refit: bool (default = True)
            Whether to refit already fitted sub-learners.

        Returns
        -------
        self: instance
            fitted estimator.
        """
        # Build the backend (stored on ``_backend``) before dispatching.
        if hasattr(self, '_build'):
            self._build()

        run(get_backend(self), 'fit', X, y,
            proba=proba, refit=refit, return_preds=False)
        return self

    def fit_transform(self, X, y, proba=False, refit=True):
        """Fit transform

        Fit estimator and return cross-validated predictions.

        Parameters
        ----------
        X: array of size [n_samples, n_features]
            input data

        y: array of size [n_samples,]
            targets

        proba: bool, optional
            whether to fit for later predict_proba calls. Will register number
            of classes to expect in later predict and transform calls.

        refit: bool (default = True)
            Whether to refit already fitted sub-learners.

        Returns
        -------
        P: array of size [n_samples, n_prediction_features]
            prediction generated by cross-validation.
        """
        # Build the backend (stored on ``_backend``) before dispatching.
        if hasattr(self, '_build'):
            self._build()

        return run(get_backend(self), 'fit', X, y, proba=proba,
                   refit=refit, return_preds=True)

    def predict(self, X, proba=False):
        """Predict

        Predict using full-fold estimator (fitted on all data).

        Parameters
        ----------
        X: array of size [n_samples, n_features]
            input data

        proba: bool, optional
            whether to predict class probabilities

        Returns
        -------
        P: array of size [n_samples, n_prediction_features]
            prediction with full-fold estimator.

        Raises
        ------
        NotFittedError
            if the instance has a falsy ``__fitted__`` attribute.
        """
        # A missing ``__fitted__`` attribute means "not tracked": skip check.
        if not getattr(self, '__fitted__', True):
            raise NotFittedError(
                "Instance not fitted (with current params).")

        return run(
            get_backend(self), 'predict', X, proba=proba, return_preds=True)

    def transform(self, X, proba=False):
        """Transform

        Use cross-validated estimators to generate predictions.

        Parameters
        ----------
        X: array of size [n_samples, n_features]
            input data

        proba: bool, optional
            whether to predict class probabilities

        Returns
        -------
        P: array of size [n_samples, n_prediction_features]
            prediction generated by cross-validation.

        Raises
        ------
        NotFittedError
            if the instance has a falsy ``__fitted__`` attribute.
        """
        # A missing ``__fitted__`` attribute means "not tracked": skip check.
        if not getattr(self, '__fitted__', True):
            raise NotFittedError(
                "Instance not fitted (with current params).")

        return run(get_backend(self), 'transform', X, proba=proba,
                   return_preds=True)


def get_backend(instance):
    """Return the parallel backend object that should process ``instance``.

    If ``instance`` has a truthy ``_backend`` attribute, that object is
    used; otherwise ``instance`` itself is the candidate.

    Raises
    ------
    ParallelProcessingError
        if the candidate's class is not a subclass of :class:`BaseParallel`.
    """
    # NOTE(review): truthiness (not ``is not None``) decides whether
    # ``_backend`` is used, matching the original behavior.
    candidate = getattr(instance, '_backend', None) or instance

    if not issubclass(candidate.__class__, BaseParallel):
        raise ParallelProcessingError(
            "The estimator does not have a backend. Cannot process.")
    return candidate


def set_flags(backend, flags):
    """Temporarily override attributes on ``backend`` and its members.

    The objects to update depend on ``backend``:

    * class name contains ``'layer'``: the backend itself plus its
      ``learners``;
    * class name contains ``'group'``: only its ``learners``;
    * a ``list``: every element of the list;
    * anything else: the backend itself.

    For every target, each key in ``flags`` that the target already has as
    an attribute is overwritten with the new value. Missing attributes are
    skipped.

    Returns
    -------
    resets: list of (object, dict)
        one entry per target, holding the previous values of the attributes
        that were overwritten. Pass it to :func:`reset_flags` to restore them.
    """
    cls_name = backend.__class__.__name__.lower()
    if 'layer' in cls_name:
        targets = [backend] + backend.learners
    elif 'group' in cls_name:
        targets = backend.learners
    elif isinstance(backend, list):
        targets = backend
    else:
        targets = [backend]

    resets = []
    for target in targets:
        saved = {}
        for name, value in flags.items():
            if not hasattr(target, name):
                continue
            saved[name] = getattr(target, name)
            setattr(target, name, value)
        resets.append((target, saved))
    return resets


def reset_flags(resets):
    """Restore attributes saved by :func:`set_flags`.

    Parameters
    ----------
    resets: list of (object, dict)
        entries in the format returned by :func:`set_flags`. Each saved
        attribute is written back with ``setattr``, in list order.
    """
    for target, saved in resets:
        for name in saved:
            setattr(target, name, saved[name])


def set_predict(kwargs):
    """Pop ``proba`` from ``kwargs`` and build the matching flag overrides.

    Parameters
    ----------
    kwargs: dict
        keyword arguments given to :func:`run`. The ``proba`` key, if
        present, is removed in place.

    Returns
    -------
    out: dict
        ``{'proba': proba, 'attr': 'predict_proba'}`` when ``proba`` is
        truthy. Otherwise an empty dict, so no ``proba``/``attr`` override
        is applied.
    """
    proba = kwargs.pop('proba', False)
    if not proba:
        # No override: instances keep their own ``proba`` / ``attr`` values.
        return dict()
    # Only reached when proba is truthy, so the method is always
    # ``predict_proba`` (the old ``else 'predict'`` branch was unreachable).
    return {'proba': proba, 'attr': 'predict_proba'}


def set_output(kwargs, job, map):
    """Compute the ``__no_output__`` flag for a job.

    If ``kwargs`` has no ``return_preds`` entry, one is added in place:
    ``False`` for ``'fit'`` jobs and ``True`` otherwise.

    Parameters
    ----------
    kwargs: dict
        keyword arguments given to :func:`run` (may be mutated).

    job: str
        job name, e.g. ``'fit'``, ``'predict'``, ``'transform'``.

    map: bool
        whether the job is a ``map`` job (``False`` means ``stack``).

    Returns
    -------
    no_output: bool
        ``not return_preds`` for map jobs, always ``False`` for stack jobs.
    """
    if 'return_preds' in kwargs:
        no_output = not kwargs['return_preds']
    else:
        is_fit = job == 'fit'
        no_output = is_fit
        kwargs['return_preds'] = job != 'fit'

    # Stacking always has to generate outputs, whatever return_preds says.
    return no_output if map else False


def run(caller, job, X, y=None, map=True, **kwargs):
    """Run a ParallelProcessing job on one caller or a list of callers.

    By default this executes::

        out = mgr.map(caller, job, X, y, **kwargs)

    :func:`run` applies temporary attribute changes to the callers for the
    duration of the job and restores them afterwards, even if the job raises.
    For example, a learner whose default is ``proba=False`` can be run with
    ``proba=True``. Instances that normally produce no output can be made
    to return predictions by passing ``return_preds=True``.

    .. note:: To run a learner with a ``preprocessing`` dependency, the
        instances need to be wrapped in a :class:`Group` ::

            run(Group(learner, transformer), 'predict', X, y)

    Parameters
    ----------
    caller: instance, list
        A runnable instance, or a list of instances.

    job: str
        type of job to run. One of ``'fit'``, ``'transform'``, ``'predict'``.

    X: array-like
        input

    y: array-like, optional
        targets

    map: bool (default=True)
        whether to run a :func:`ParallelProcessing.map` job. If ``False``,
        will instead run a :func:`ParallelProcessing.stack` job.

    **kwargs: optional
        Keyword arguments. ``proba`` is consumed here and turned into
        temporary attribute overrides. ``return_preds`` controls the
        temporary ``__no_output__`` override and is still forwarded. All
        remaining ``kwargs`` are passed to either ``map`` or ``stack``.
    """
    # Build the temporary overrides; ``set_predict`` pops ``proba`` and
    # ``set_output`` fills in a default ``return_preds`` when absent.
    flags = set_predict(kwargs)
    flags['__no_output__'] = set_output(kwargs, job, map)

    resets = set_flags(caller, flags)
    try:
        # The caller's verbosity is shifted down by 4, floored at 0.
        verbosity = max(getattr(caller, 'verbose', 0) - 4, 0)
        engine = getattr(caller, 'backend', config.get_backend())
        workers = getattr(caller, 'n_jobs', -1)
        with ParallelProcessing(engine, workers, verbosity) as manager:
            execute = manager.map if map else manager.stack
            result = execute(caller, job, X, y, **kwargs)
    finally:
        # Restore the callers' original attribute values in all cases.
        reset_flags(resets)
    return result
